"""
The code is released exclusively for review purposes with the following terms:
PROPRIETARY AND CONFIDENTIAL. UNAUTHORIZED USE, COPYING, OR DISTRIBUTION OF THE 
CODE, VIA ANY MEDIUM, IS STRICTLY PROHIBITED. BY ACCESSING THE CODE, THE 
REVIEWERS AGREE TO DELETE THEM FROM ALL MEDIA AFTER THE REVIEW PERIOD IS OVER.
"""
# Utility functions for creating explanations

import sklearn
import numpy as np
import pandas as pd
import sklearn.datasets
import sklearn.ensemble
import itertools
import random
import json,os
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from lime.lime_text import IndexedString
import scipy
from collections import namedtuple
from sklearn.metrics import mean_absolute_error
from scipy.stats import pearsonr
from collections import namedtuple
from itertools import combinations, product, chain
from sklearn import linear_model
from sklearn.utils import check_random_state
import errno
import logging

from threadpoolctl import threadpool_limits

def compute_categorical_freqs(df, categorical_feature_inds):
    """Compute the relative frequency of every value of each categorical column.

    Arguments:
        df (pandas DataFrame) - data frame with all the variables
        categorical_feature_inds (list) - positional column indices of the
            categorical features in the data

    Returns:
        cat_freqs (dict) - maps each column index to a tuple
            `(values, freqs)`, where `values` are the unique values of the
            column as returned by `np.unique` and `freqs` are their counts
            divided by the total count
    """
    def _relative_counts(column):
        # np.unique returns the distinct values together with their counts
        levels, counts = np.unique(column, return_counts=True)
        return levels, counts / sum(counts)

    return {col_idx: _relative_counts(df.iloc[:, col_idx])
            for col_idx in categorical_feature_inds}

def train_perturbation(df, categorical_feature_names, cond_prob_train=False):
    """Collect the data statistics needed to perturb instances and, optionally,
    fit conditional probability predictors for the categorical features.

    Arguments:
        df (pandas DataFrame) - data frame with all the variables
        categorical_feature_names (iterable) - names of the categorical
            features; each one is looked up in `df.columns`
        cond_prob_train (bool) - when True, also fit the conditional
            probability predictors

    Returns:
        scaler_mean (array) - per-column mean (forced to 0 for cat features)
        scaler_std (array) - per-column std. dev. (forced to 1 for cat features)
        cat_freqs (dict) - output of `compute_categorical_freqs` for the
            categorical columns
        cond_prob_predictor (dict or None) - output of
            `compute_conditional_prob_predictor`, or None when
            `cond_prob_train` is False
    """
    fitted = StandardScaler().fit(df)

    # Translate feature names into positional column indices
    column_order = list(df.columns)
    cat_positions = [column_order.index(name)
                     for name in categorical_feature_names]

    # Categorical columns must not be centred or rescaled
    fitted.mean_[cat_positions] = 0.0
    fitted.var_[cat_positions] = 1.0

    cat_freqs = compute_categorical_freqs(df, cat_positions)

    cond_prob_predictor = (
        compute_conditional_prob_predictor(df, cat_positions)
        if cond_prob_train else None)

    return fitted.mean_, np.sqrt(fitted.var_), cat_freqs, cond_prob_predictor

def compute_conditional_prob_predictor(df, categorical_feature_inds):
    """Fit, for every categorical feature, a classifier that predicts it from
    all the other columns.

    Arguments:
        df (pandas DataFrame) - data frame with all the variables
        categorical_feature_inds (list) - positional indices of the
            categorical features in the data

    Returns:
        cond_prob_predictor (dict) - maps each categorical column index to a
            tuple `(classes, model)` with the fitted `LogisticRegressionCV`
            and its `classes_`
    """
    from sklearn.linear_model import LogisticRegressionCV

    all_columns = df.columns
    predictors = {}

    for target_idx in categorical_feature_inds:
        target_name = all_columns[target_idx]
        target = np.array(df.iloc[:, target_idx].values)

        # Design matrix: every other column, dummy coded, stored sparse
        design_frame = pd.get_dummies(df.drop(columns=target_name),
                                      prefix_sep="=", sparse=True,
                                      dtype=np.float64)
        design = scipy.sparse.csc_matrix(design_frame.values)

        model = LogisticRegressionCV(penalty="l2", max_iter=1000, n_jobs=-1)
        model.fit(design, target)
        labels = model.classes_
        # Report the training accuracy of the fitted predictor
        print(target_name, labels, model.score(design, target))

        predictors[target_idx] = (labels, model)

    return predictors

def numerical_perturbation(samp, feat_mean, feat_std, 
                           numerical_feature_inds,
                           num_perturbations,
                           random_state,
                           sample_around_instance=True):
    """Create Gaussian perturbations of the numerical features of one sample

    Arguments:
        samp (pandas DataFrame) - data frame (1 row) with the sample to be 
            perturbed
        feat_mean (array) - mean of the features
        feat_std (array) - std. dev. of the features
        numerical_feature_inds (array) - indices of the numerical features
        num_perturbations (int) - number of rows in the output (including
            the unperturbed first row)
        random state (RandomState) - Random state used to draw the noise
        sample_around_instance (bool) - centre the noise on the instance
            (True) or on the mean of the dataset (False)

    Returns:
        samp_pert (pandas DataFrame) - `num_perturbations` rows with the
            numerical features perturbed. First row is the original data
            point.
    """
    num_cols = len(numerical_feature_inds)

    # Start from num_perturbations identical copies of the sample
    replicated = pd.concat([samp] * num_perturbations, ignore_index=True)

    if sample_around_instance:
        center = replicated.iloc[:, numerical_feature_inds]
    else:
        center = feat_mean[numerical_feature_inds]

    # Standard normal noise scaled by the per-feature std. dev.
    jitter = (feat_std[numerical_feature_inds] *
              random_state.randn(num_perturbations, num_cols))
    replicated.iloc[:, numerical_feature_inds] = center + jitter

    # Row 0 always carries the unperturbed instance
    replicated.iloc[0:1, numerical_feature_inds] = \
        samp.iloc[0:1, numerical_feature_inds].values

    return replicated

def categorical_perturbation(samp_pert, cat_freqs,
                             cond_prob_predictor,
                             categorical_feature_inds,
                             num_perturbations,
                             random_state,
                             categorical_sampling="basic",
                             cat_feats_to_perturb=1,
                             bias_category=0.0):
    """Perturbations for categorical features

    Arguments:
        samp_pert (pandas DataFrame) - data frame with the examples to be perturbed.
            This is the output of the function `numerical_perturbation`.
            It is modified in place (rows 1 onwards of the categorical
            columns are overwritten).
        cat_freqs (dict) - maps categorical column index to a tuple
            `(values, freqs)` of category values and their frequencies
        cond_prob_predictor (dict) - Key-value pairs of conditional probability 
            predictor for categorical features; only read when
            `categorical_sampling == "enhanced"`
        categorical_feature_inds (list) - indices of categorical features
        num_perturbations (int) - number of rows in `samp_pert` (the first
            row is the original instance, so `num_perturbations - 1` rows
            are perturbed)
        random state (RandomState) - Random state used for every random draw
            in this function (perturbation mask and category samples)
        categorical_sampling (str) - "enhanced" samples categories from the
            conditional probability predictor evaluated on the first row;
            any other value samples from the marginal frequencies in
            `cat_freqs`
        cat_feats_to_perturb (int) - expected number of categorical features
            to perturb per row; when it is at least the number of
            categorical features, all of them are perturbed
        bias_category (float) - in "enhanced" mode, amount added to the
            predicted probability of the instance's own category before
            renormalising

    Returns:
        samp_pert (pandas DataFrame) - Perturbed data with categorical features 
            perturbed. First row is the original data point.
        samp_pert_binary (pandas DataFrame) - `pd.get_dummies` of `samp_pert`
            where, for rows 1 onwards, the columns at
            `categorical_feature_inds` are multiplied by the first row's
            values in those columns.
    """
    # Column names
    col_names = samp_pert.columns

    num_cat_features = len(cat_freqs)
    cat_freqs_keys = list(cat_freqs.keys())
    # Row 0 is the original instance and is never perturbed
    num_pert = num_perturbations - 1

    # Boolean mask selecting which (row, categorical feature) cells change
    if cat_feats_to_perturb < num_cat_features:
        perturb_prob = cat_feats_to_perturb / num_cat_features
        # Drawn from `random_state` (not the global numpy RNG) so that the
        # result is reproducible from the seed alone
        mask_mat = random_state.choice([False, True],
                                       size=(num_pert, num_cat_features),
                                       p=[1.0 - perturb_prob, perturb_prob])
    else:
        # `np.bool` was removed in NumPy 1.24; the builtin is equivalent
        mask_mat = np.ones((num_pert, num_cat_features), dtype=bool)

    # Candidate category for every (row, categorical feature) cell
    samp_pert_all = np.zeros((num_pert, num_cat_features))

    for col_pos, (cat_ind, (vals, freqs)) in enumerate(cat_freqs.items()):
        col_name = col_names[cat_ind]
        cat_val_base = samp_pert.iat[0, cat_ind]

        if categorical_sampling == "enhanced":
            # Predict the category distribution from the other features of
            # the original instance (row 0)
            samp_pert_sel = scipy.sparse.csc_matrix(
                        pd.get_dummies(samp_pert.iloc[0:1, :].drop(columns=col_name),
                        prefix_sep="=", sparse=True, 
                        dtype=np.float64).values)
            prob_vals = cond_prob_predictor[cat_ind][1].predict_proba(
                                samp_pert_sel)[0]
            classes = cond_prob_predictor[cat_ind][1].classes_

            if bias_category > 0.0:
                # Favour the instance's own category, then renormalise
                base_idx = np.where(classes == cat_val_base)[0][0]
                prob_vals[base_idx] = prob_vals[base_idx] + bias_category
                prob_vals = prob_vals/np.sum(prob_vals)
        else:
            classes = vals
            prob_vals = freqs

        samp_pert_all[:, col_pos] = random_state.choice(
            classes, size=num_pert, replace=True, p=prob_vals)

    # Restrict cat perturbations to the masked cells. The array is copied
    # explicitly because `.values` may be a read-only view.
    samp_pert_small = samp_pert.iloc[1:, cat_freqs_keys].values.copy()
    samp_pert_small[mask_mat] = samp_pert_all[mask_mat]
    samp_pert.iloc[1:, cat_freqs_keys] = samp_pert_small

    # Binary (dummy-coded) categorical variables
    samp_pert_binary = pd.get_dummies(samp_pert, prefix_sep="=", dtype=np.float64)
    samp_pert_binary.iloc[1:, categorical_feature_inds] =\
            samp_pert_binary.iloc[0:1, categorical_feature_inds].values *\
            samp_pert_binary.iloc[1:, categorical_feature_inds].values

    return samp_pert, samp_pert_binary

def create_perturbation(s1, feat_mean, feat_std, cat_freqs,
                        cond_prob_predictor,
                        numerical_feature_inds,
                        categorical_feature_inds,
                        num_perturbations,
                        random_state,
                        categorical_sampling,
                        cat_feats_to_perturb=1,
                        bias_category=0.0,
                        sample_around_instance=False):
    """Perturbations for numerical and categorical features together

    Arguments:
        s1 (pandas DataFrame) - data frame (1 row) with the sample to be 
            perturbed
        feat_mean (array) - mean of the features
        feat_std (array) - std. dev. of the features
        cat_freqs (dict) - categorical ids and frequencies
        cond_prob_predictor (dict) - Key-value pairs of conditional probability 
            predictor for categorical features
        numerical_feature_inds (array) - indices of the numerical features
        categorical_feature_inds (list) - indices of categorical features
        num_perturbations (int) - number of perturbations to be created
        random state (RandomState) - Random state used to create perturbations
        categorical_sampling (str) - "basic" or "enhanced" categorical sampling
        cat_feats_to_perturb (int) - number of cat features to perturb; a
            value <= 0 skips the categorical perturbation entirely
        bias_category (float) - bias to the category
        sample_around_instance (bool) - sample around the instance (True) or around
            the mean of the dataset (False)

    Returns:
        samp_pert (pandas DataFrame) - Perturbed data with numerical and categorical
            features perturbed. First row is the original data point.
        samp_pert_binary (pandas DataFrame) - Same as `samp_pert` except categorical
            features are binary and not categorical. When no categorical
            perturbation is performed this is a plain copy of `samp_pert`.
    """
    # Forward the caller's choice; it used to be hard-coded to False, which
    # silently ignored the `sample_around_instance` argument.
    s1_pert = numerical_perturbation(s1, feat_mean, feat_std, 
                                numerical_feature_inds,
                                num_perturbations,
                                random_state,
                                sample_around_instance=sample_around_instance)

    if len(categorical_feature_inds) and cat_feats_to_perturb > 0:
        s1_pert, s1_pert_binary = categorical_perturbation(s1_pert, cat_freqs, 
                                    cond_prob_predictor,
                                    categorical_feature_inds,
                                    num_perturbations, random_state,
                                    categorical_sampling,
                                    cat_feats_to_perturb=cat_feats_to_perturb,
                                    bias_category=bias_category)
    else:
        # Nothing categorical to perturb
        s1_pert_binary = s1_pert.copy()

    return s1_pert, s1_pert_binary


def create_multiple_perturbations(s1, 
                        random_seeds,
                        perturb_config):
    """Create one independent set of perturbations per random seed

    Arguments:
        s1 (pandas DataFrame) - data frame (1 row) with the sample to be 
            perturbed
        random_seeds (list) - one perturbation set is created for each seed
        perturb_config (dict) - kwargs to be passed for `create_perturbation`.
            Its "random_state" entry is overwritten for every seed.

    Returns:
        samp_perts (list) - pandas DataFrames with perturbed data
        samp_perts_binary (list) - pandas DataFrames with perturbed data and binary-valued
            features in place of categorical ones
    """
    pairs = []
    for seed in random_seeds:
        # Each seed gets its own random state
        perturb_config["random_state"] = check_random_state(seed)
        pert, pert_binary = create_perturbation(s1, **perturb_config)
        pairs.append((pert, pert_binary))

    samp_perts = [pert for pert, _ in pairs]
    samp_perts_binary = [pert_binary for _, pert_binary in pairs]
    return samp_perts, samp_perts_binary

def create_multiple_perturbations_bootstrap(s1, 
                        random_seeds,
                        perturb_config):
    """Multiple Perturbations for numerical and categorical features together
    A single perturbation is created and multiple bootstrap samples are drawn
    from it, one per random seed.

    Arguments:
        s1 (pandas DataFrame) - data frame (1 row) with the sample to be 
            perturbed
        random_seeds (list) - random seeds; one bootstrap sample is drawn per
            seed, and the base perturbation is seeded with the integer part
            of their mean
        perturb_config (dict) - kwargs to be passed for `create_perturbation`.
            Must contain "num_perturbations"; its "random_state" entry is
            overwritten.

    Returns:
        samp_pert (pandas DataFrame) - the base perturbed data
        samp_pert_binary (pandas DataFrame) - the base perturbed data with
            binary-valued features in place of categorical ones
        samp_perts (list) - bootstrap resamples of `samp_pert`
        samp_perts_binary (list) - bootstrap resamples of `samp_pert_binary`
            (same row indices as the corresponding entry of `samp_perts`)
    """
    samp_perts = []
    samp_perts_binary = []

    num_perturbations = perturb_config["num_perturbations"]

    base_random_seed = int(np.mean(random_seeds))
    perturb_config["random_state"] = check_random_state(base_random_seed)
    samp_pert, samp_pert_binary = create_perturbation(s1,
                                        **perturb_config)

    # bootstrap samples are created here
    candidate_rows = np.arange(1, num_perturbations)
    for seed in random_seeds:
        # Seed every resample individually so results are reproducible;
        # previously `seed` was unused and the global numpy RNG was sampled.
        rng = check_random_state(seed)

        # Index 0 stays 0 so the original instance remains the first row;
        # the other rows are resampled with replacement from rows 1..end.
        samp_inds = np.zeros(num_perturbations, dtype=int)
        samp_inds[1:] = rng.choice(candidate_rows,
                                   size=num_perturbations - 1)

        samp_perts.append(samp_pert.iloc[samp_inds, :])
        samp_perts_binary.append(samp_pert_binary.iloc[samp_inds, :])

    return samp_pert, samp_pert_binary, samp_perts, samp_perts_binary

def scale_data(samp_pert_binary, 
                        numerical_feature_inds, feat_mean, feat_std):
    """Standardise the numerical columns: subtract the mean, divide by the std.

    Arguments:
        samp_pert_binary (pandas DataFrame) - Perturbed data. First row is 
            the original data point. It is not modified.
        numerical_feature_inds (array) - indices of the numerical features
        feat_mean (array) - mean of the features
        feat_std (array) - std. dev. of the features

    Returns:
        samp_pert_binary_scaled (np ndarray) - Scaled data; columns other than
            the numerical ones are passed through unchanged
    """
    scaled = samp_pert_binary.copy()
    centered = (samp_pert_binary.iloc[:, numerical_feature_inds] -
                feat_mean[numerical_feature_inds])
    scaled.iloc[:, numerical_feature_inds] = \
        centered / feat_std[numerical_feature_inds]
    return scaled.values

def compute_weights(S1, 
                    distance_metric="euclidean", 
                    kernel_width=0.75,
                    normalize=False):
    """Kernel weights of every row with respect to the first (original) row
    of the data matrix.

    Arguments:
        S1 (2-D array) - data with original and perturbed instances; the
            first row is accessed as `S1[0]`
        distance_metric (str) - metric passed to
            `sklearn.metrics.pairwise_distances`
        kernel_width (float) - width of Gaussian kernel
        normalize (bool) - whether to normalize weights so they sum to one

    Returns:
        weights (np ndarray) - one weight per row, computed as
            `sqrt(exp(-d**2 / kernel_width**2))` of the distance `d`
    """
    # Only needed by the alternative width np.sqrt(n_cols) * kernel_width,
    # which is currently not used
    n_cols = S1.shape[1]

    # Distance of every row to the first row
    anchor = S1[0].reshape(1, -1)
    dists = sklearn.metrics.pairwise_distances(
        S1, anchor, metric=distance_metric).ravel()

    weights = np.sqrt(np.exp(-(dists ** 2) / kernel_width ** 2))

    return weights / np.sum(weights) if normalize else weights

def create_text_perturbations(indexed_string, num_samples, 
                              combined_vocab, random_state):
    """Create text perturbations by removing random subsets of words

    Arguments:
        indexed string (lime indexed string) - Indexed string used in LIME
        num_samples (int) - number of rows to create (the first one is the
            unperturbed text)
        combined_vocab (list) - vocabulary defining the output columns; it
            must contain the entries of `indexed_string.inverse_vocab`
        random_state (RandomState) - numpy random state

    Returns:
        perturbed_data (np ndarray) - `num_samples` x `len(combined_vocab)`
            matrix; 1 where a word of the string is kept, 0 where it is
            removed or where the vocabulary entry is not in the string
        inverse_data (list) - list of perturbed strings, starting with the
            raw string
    """
    n_words = indexed_string.num_words()
    word_ids = range(n_words)

    # How many words to remove in each perturbed row (between 1 and n_words)
    removal_counts = random_state.randint(1, n_words + 1, num_samples - 1)

    # Row 0 keeps every word
    presence = np.ones((num_samples, n_words))
    texts = [indexed_string.raw_string()]
    for row, n_removed in enumerate(removal_counts, start=1):
        dropped = random_state.choice(word_ids, n_removed, replace=False)
        presence[row, dropped] = 0
        texts.append(indexed_string.inverse_removing(dropped))

    # Embed the per-word indicators into the combined vocabulary
    bag_of_words = pd.DataFrame(
        data=np.zeros((num_samples, len(combined_vocab))),
        columns=combined_vocab)
    bag_of_words.loc[:, indexed_string.inverse_vocab] = presence

    return bag_of_words.values, texts


def lime_explanation(S12, y12, w, num_nonzeros=5, debias=False):
    """Sparse weighted linear (LIME) explanation

    Arguments:
        S12 (np ndarray) - data matrix
        y12 (np ndarray) - response matrix
        w (np ndarray) - weights
        num_nonzeros (int) - number of non-zeros in the explanation; when it
            is at least the number of columns of `S12`, all features are
            used and a least squares fit is always performed
        debias (bool) - compute least squares using the chosen features

    Returns:
        coef_deb (np ndarray) - one coefficient per column of `S12`: either
            the coefficients taken from the lasso path, or (when debiasing)
            the least squares coefficients on the selected features with
            zeros elsewhere
    """
    n_features = S12.shape[1]
    # Weighted least squares is solved by scaling rows with sqrt(w)
    sqrt_w = np.sqrt(w)

    if num_nonzeros >= n_features:
        # Every feature is selected; only the refit below is meaningful
        coef = np.ones(n_features)
        debias = True
    else:
        coef_path = linear_model.lars_path(
            sqrt_w.reshape(-1, 1) * S12,
            sqrt_w * y12,
            method="lasso",
            positive=False,
            max_iter=num_nonzeros)[2]
        # Use the requested step of the path, or the last one available
        step = min(num_nonzeros, coef_path.shape[1] - 1)
        coef = coef_path[:, step]

    if not debias:
        return coef

    # Refit without intercept on the features selected above
    support = np.abs(coef) > 1e-6
    refit = linear_model.LinearRegression(fit_intercept=False)
    refit.fit(sqrt_w.reshape(-1, 1) * S12[:, support], sqrt_w * y12)

    coef_deb = np.zeros(n_features)
    coef_deb[support] = refit.coef_
    return coef_deb

def create_fname_suffix(fname_config):
    """Build a file name suffix from a configuration.

    Arguments:
        fname_config (list, tuple or dict) - either a sequence of iterables
            (e.g. tuples) whose items are concatenated in order, or a dict
            whose keys and values are interleaved in iteration order

    Returns:
        suffix (str) - every item converted with `str`, shortened with
            `shorten_string` and joined with "_"

    Raises:
        TypeError - if `fname_config` is not a list, tuple or dict
    """
    from itertools import chain

    if isinstance(fname_config, dict):
        groups = fname_config.items()
    elif isinstance(fname_config, (list, tuple)):
        groups = fname_config
    else:
        # Fail with a clear message instead of an unbound local variable
        raise TypeError(
            "fname_config must be a list, tuple or dict, got %s"
            % type(fname_config).__name__)

    # Flatten one level, then shorten and join the individual tokens
    return "_".join(shorten_string(str(token))
                    for token in chain.from_iterable(groups))

def shorten_string(s):
    """Abbreviate an underscore-separated string to its initials.

    Arguments:
        s (str) - string to shorten

    Returns:
        smod (str) - `s` itself when it contains no underscore; otherwise
            the first character of every non-empty underscore-separated
            part, concatenated (e.g. "kernel_width" -> "kw")
    """
    parts = s.split("_")
    if len(parts) == 1:
        return parts[0]
    # Empty parts come from leading, trailing or doubled underscores and
    # have no first character, so they are skipped.
    return "".join(part[0] for part in parts if part)

def fname_data(config, dataset_name):
    """File name suffix for a dataset.

    The suffix is built from `dataset_name` followed by the key-value pairs
    of `config["Data_Preproc"][dataset_name]`.
    """
    tokens = [(dataset_name, ),
              *config["Data_Preproc"][dataset_name].items()]
    return create_fname_suffix(tokens)

def fname_model(config, model_name, dataset_name):
    """File name suffix for a model trained on a dataset.

    The suffix is built, in order, from `model_name` and the key-value pairs
    of `config["Bb_Model"][model_name]`, then `dataset_name` and the
    key-value pairs of `config["Data_Preproc"][dataset_name]`.
    """
    tokens = [(model_name, ),
              *config["Bb_Model"][model_name].items(),
              (dataset_name, ),
              *config["Data_Preproc"][dataset_name].items()]
    return create_fname_suffix(tokens)

def fname_base_perts(config, base_pert_name, dataset_name):
    """File name suffix for base perturbations of a dataset.

    The suffix is built, in order, from `base_pert_name` and the key-value
    pairs of `config[base_pert_name]`, then `dataset_name` and the key-value
    pairs of `config["Data_Preproc"][dataset_name]`.
    """
    tokens = [(base_pert_name, ),
              *config[base_pert_name].items(),
              (dataset_name, ),
              *config["Data_Preproc"][dataset_name].items()]
    return create_fname_suffix(tokens)

def fname_preds(config, base_pert_name, model_name, dataset_name):
    """File name suffix for predictions.

    The suffix is built, in order, from the literal "Preds" and the
    key-value pairs of `config["Preds"]`, then the base perturbation, model
    and dataset names, each followed by the key-value pairs of its
    configuration section (`config[base_pert_name]`,
    `config["Bb_Model"][model_name]`,
    `config["Data_Preproc"][dataset_name]`).
    """
    tokens = [("Preds", ),
              *config["Preds"].items(),
              (base_pert_name, ),
              *config[base_pert_name].items(),
              (model_name, ),
              *config["Bb_Model"][model_name].items(),
              (dataset_name, ),
              *config["Data_Preproc"][dataset_name].items()]
    return create_fname_suffix(tokens)

def fname_env_perts(config, env_pert_name, 
                base_pert_name, dataset_name):
    """File name suffix for environment perturbations.

    The suffix is built, in order, from the environment perturbation, base
    perturbation and dataset names, each followed by the key-value pairs of
    its configuration section (`config[env_pert_name]`,
    `config[base_pert_name]`, `config["Data_Preproc"][dataset_name]`).
    """
    tokens = [(env_pert_name, ),
              *config[env_pert_name].items(),
              (base_pert_name, ),
              *config[base_pert_name].items(),
              (dataset_name, ),
              *config["Data_Preproc"][dataset_name].items()]
    return create_fname_suffix(tokens)

def fname_exp(config, exp_name, env_pert_name, 
                    base_pert_name, model_name, dataset_name):
    """File name suffix for an explanation.

    The suffix is built, in order, from `exp_name` and the key-value pairs
    of `config[exp_name]`, the literal "Preds" and the key-value pairs of
    `config["Preds"]`, then the environment perturbation, base perturbation,
    model and dataset names, each followed by the key-value pairs of its
    configuration section (`config[env_pert_name]`,
    `config[base_pert_name]`, `config["Bb_Model"][model_name]`,
    `config["Data_Preproc"][dataset_name]`).
    """
    tokens = [(exp_name, ),
              *config[exp_name].items(),
              ("Preds", ),
              *config["Preds"].items(),
              (env_pert_name, ),
              *config[env_pert_name].items(),
              (base_pert_name, ),
              *config[base_pert_name].items(),
              (model_name, ),
              *config["Bb_Model"][model_name].items(),
              (dataset_name, ),
              *config["Data_Preproc"][dataset_name].items()]
    return create_fname_suffix(tokens)


def create_dir_if_not_exist(dirname):
    """Create a directory (and missing parents) unless the path already exists.

    Arguments:
        dirname (str) - path of the directory

    Returns:
        1 in every non-raising case

    Raises:
        OSError - if the directory cannot be created for any reason other
            than it having been created concurrently
    """
    if os.path.exists(dirname):
        logging.info('Directory %s already exists', dirname)
        return 1

    try:
        os.makedirs(dirname)
    except OSError as exc:
        # Guard against race condition: another process may have created
        # the directory between the existence check and makedirs
        created_meanwhile = (exc.errno == errno.EEXIST
                             and os.path.isdir(dirname))
        if not created_meanwhile:
            raise
    logging.info('Created directory %s', dirname)

    return 1